{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lf7huAiYp-An"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "YHz2D-oIqBWa"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x44FFES-r6y0"
      },
      "source": [
        "# Working with tff's ClientData."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/working_with_client_data\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/working_with_client_data.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/tutorials/working_with_client_data.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/working_with_client_data.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8RVecD0EfXdb"
      },
      "source": [
        "The notion of a dataset keyed by clients (e.g. users) is essential to federated computation as modeled in TFF. TFF provides the interface [`tff.simulation.datasets.ClientData`](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/ClientData) to abstract over this concept, and the datasets which TFF hosts ([stackoverflow](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/stackoverflow), [shakespeare](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare), [emnist](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist), [cifar100](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/cifar100), and [gldv2](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/gldv2)) all implement this interface.\n",
        "\n",
        "If you are working on federated learning with your own dataset, TFF strongly encourages you to either implement the `ClientData` interface or use one of TFF's helper functions to generate a `ClientData` which represents your data on disk, e.g. [`tff.simulation.datasets.ClientData.from_clients_and_fn`](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/ClientData?version=nightly#from_clients_and_fn).\n",
        "\n",
        "As most of TFF's end-to-end examples start with `ClientData` objects, implementing the `ClientData` interface with your custom dataset will make it easier to spelunk through existing code written with TFF. Further, the `tf.data.Datasets` which `ClientData` constructs can be iterated over directly to yield structures of `numpy` arrays, so `ClientData` objects can be used with any Python-based ML framework before moving to TFF.\n",
        "\n",
        "There are several patterns with which you can make your life easier if you intend to scale up your simulations to many machines or deploy them. Below we will walk through a few of the ways we can use `ClientData` and TFF to make our small-scale iteration-to large-scale experimentation-to production deployment experience as smooth as possible."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "snsz06ESrGvL"
      },
      "source": [
        "## Which pattern should I use to pass ClientData into TFF?\n",
        "\n",
        "We will discuss two usages of TFF's `ClientData` in depth; if you fit in either of the two categories below, you will clearly prefer one over the other. If not, you may need a more detailed understanding of the pros and cons of each to make a more nuanced choice.\n",
        "* I want to iterate as quickly as possible on a local machine; I don't need to be able to easily take advantage of TFF's distributed runtime.\n",
        " * You want to pass `tf.data.Datasets` in to TFF directly.\n",
        " * This allows you to program imperatively with `tf.data.Dataset` objects, and process them arbitrarily.\n",
        " * It provides more flexibility than the option below; pushing logic to the clients requires that this logic be serializable.\n",
        "\n",
        "* I want to run my federated computation in TFF's remote runtime, or I plan to do so soon.\n",
        " * In this case you want to map dataset construction and preprocessing to clients.\n",
        " * This results in you passing simply a list of `client_ids` directly to your federated computation.\n",
        " * Pushing dataset construction and preprocessing to the clients avoids bottlenecks in serialization, and significantly increases performance with hundreds-to-thousands of clients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KoCHeay4Rozd"
      },
      "outputs": [],
      "source": [
        "#@title Set up open-source environment\n",
        "#@test {\"skip\": true}\n",
        "\n",
        "# tensorflow_federated_nightly also bring in tf_nightly, which\n",
        "# can causes a duplicate tensorboard install, leading to errors.\n",
        "!pip uninstall --yes tensorboard tb-nightly\n",
        "\n",
        "!pip install --quiet --upgrade tensorflow_federated_nightly\n",
        "!pip install --quiet --upgrade nest_asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LNduVQsPNoH7"
      },
      "outputs": [],
      "source": [
        "#@title Import packages\r\n",
        "import collections\r\n",
        "import time\r\n",
        "\r\n",
        "import tensorflow as tf\r\n",
        "import tensorflow_federated as tff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dNOfCerkfZh_"
      },
      "source": [
        "## Manipulating a ClientData object\n",
        "\n",
        "Let's begin by loading and exploring TFF's EMNIST `ClientData`:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rd8vaOOfbe5X"
      },
      "outputs": [],
      "source": [
        "client_data, _ = tff.simulation.datasets.emnist.load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-46eXnKbmYP"
      },
      "source": [
        "Inspecting the first dataset can tell us what type of examples are in the `ClientData`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N1JvJvDkbxDo"
      },
      "outputs": [],
      "source": [
        "first_client_id = client_data.client_ids[0]\n",
        "first_client_dataset = client_data.create_tf_dataset_for_client(\n",
        "    first_client_id)\n",
        "print(first_client_dataset.element_spec)\n",
        "# This information is also available as a `ClientData` property:\n",
        "assert client_data.element_type_structure == first_client_dataset.element_spec"
      ]
    },
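    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because each client dataset is an ordinary `tf.data.Dataset`, it can also be iterated over outside of TFF. The following is a minimal sketch, assuming the `first_client_dataset` constructed above; it uses `as_numpy_iterator` to pull a single example as a structure of `numpy` arrays."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Pull one example out of the (un-preprocessed) client dataset as numpy arrays.\n",
        "for example in first_client_dataset.take(1).as_numpy_iterator():\n",
        "  # Each example is a structure keyed by the names in the element spec.\n",
        "  print(type(example['pixels']), example['pixels'].shape)\n",
        "  print(example['label'])"
      ]
    },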
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Z8l3uuYv8cD"
      },
      "source": [
        "Note that the dataset yields `collections.OrderedDict` objects that have `pixels` and `label` keys, where pixels is a tensor with shape `[28, 28]`. Suppose we wish to flatten our inputs out to shape `[784]`. One possible way we can do this would be to apply a pre-processing function to our `ClientData` object."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VyPqaw6Uv7Fu"
      },
      "outputs": [],
      "source": [
        "def preprocess_dataset(dataset):\n",
        "  \"\"\"Create batches of 5 examples, and limit to 3 batches.\"\"\"\n",
        "\n",
        "  def map_fn(input):\n",
        "    return collections.OrderedDict(\n",
        "        x=tf.reshape(input['pixels'], shape=(-1, 784)),\n",
        "        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),\n",
        "    )\n",
        "\n",
        "  return dataset.batch(5).map(\n",
        "      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)\n",
        "\n",
        "\n",
        "preprocessed_client_data = client_data.preprocess(preprocess_dataset)\n",
        "\n",
        "# Notice that we have both reshaped and renamed the elements of the ordered dict.\n",
        "first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(\n",
        "    first_client_id)\n",
        "print(first_client_dataset.element_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NtpLRgdpl9Js"
      },
      "source": [
        "We may want in addition to perform some more complex (and possibly stateful) preprocessing, for example shuffling."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CtBVHcAmmKiu"
      },
      "outputs": [],
      "source": [
        "def preprocess_and_shuffle(dataset):\n",
        "  \"\"\"Applies `preprocess_dataset` above and shuffles the result.\"\"\"\n",
        "  preprocessed = preprocess_dataset(dataset)\n",
        "  return preprocessed.shuffle(buffer_size=5)\n",
        "\n",
        "preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)\n",
        "\n",
        "# The type signature will remain the same, but the batches will be shuffled.\n",
        "first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(\n",
        "    first_client_id)\n",
        "print(first_client_dataset.element_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ek7W3ZZHMr1k"
      },
      "source": [
        "## Interfacing with a `tff.Computation`\n",
        "\n",
        "Now that we can perform some basic manipulations with `ClientData` objects, we are ready to feed data to a `tff.Computation`. We define a [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess) which implements [Federated Averaging](https://arxiv.org/abs/1602.05629), and explore different methods of passing it data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j41nKFYse8GC"
      },
      "outputs": [],
      "source": [
        "def model_fn():\n",
        "  model = tf.keras.models.Sequential([\n",
        "      tf.keras.layers.InputLayer(input_shape=(784,)),\n",
        "      tf.keras.layers.Dense(10, kernel_initializer='zeros'),\n",
        "  ])\n",
        "  return tff.learning.from_keras_model(\n",
        "      model,\n",
        "      # Note: input spec is the _batched_ shape, and includes the \n",
        "      # label tensor which will be passed to the loss function. This model is\n",
        "      # therefore configured to accept data _after_ it has been preprocessed.\n",
        "      input_spec=collections.OrderedDict(\n",
        "          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n",
        "          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
        "  \n",
        "trainer = tff.learning.build_federated_averaging_process(\n",
        "    model_fn,\n",
        "    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ICJdME7-5lMx"
      },
      "source": [
        "Before we begin working with this `IterativeProcess`, one comment on the semantics of `ClientData` is in order. A `ClientData` object represents the *entirety* of the population available for federated training, which in general is [not available to the execution environment of a production FL system](https://arxiv.org/abs/1902.01046) and is specific to simulation. `ClientData` indeed gives the user the capacity to bypass federated computing entirely and simply train a server-side model as usual via [`ClientData.create_tf_dataset_from_all_clients`](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/ClientData?hl=en\u0026version=nightly#create_tf_dataset_from_all_clients).\n",
        "\n",
        "TFF's simulation environment puts the researcher in complete control of the outer loop. In particular this implies considerations of client availability, client dropout, etc, must be addressed by the user or Python driver script. One could for example model client dropout by adjusting the sampling distribution over your `ClientData's` `client_ids` such that users with more data (and correspondingly longer-running local computations) would be selected with lower probability.\n",
        "\n",
        "In a real federated system, however, clients cannot be selected explicitly by the model trainer; the selection of clients is delegated to the system which is executing the federated computation."
      ]
    },
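    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an illustration of driving client selection from Python, the sketch below samples clients with probability inversely proportional to their number of examples. The candidate pool of 20 clients, the inverse-size weighting, the seed, and all variable names here are assumptions made for this example; they are not part of TFF."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical candidate pool; counting examples for every client would be slow.\n",
        "candidate_ids = client_data.client_ids[:20]\n",
        "\n",
        "# Count the examples held by each candidate client.\n",
        "num_examples = np.array([\n",
        "    sum(1 for _ in client_data.create_tf_dataset_for_client(client_id))\n",
        "    for client_id in candidate_ids\n",
        "], dtype=np.float64)\n",
        "\n",
        "# Clients with more data are selected with lower probability.\n",
        "# Clip to 1 to guard against division by zero for an empty client.\n",
        "weights = 1.0 / np.maximum(num_examples, 1.0)\n",
        "probabilities = weights / weights.sum()\n",
        "\n",
        "rng = np.random.default_rng(seed=0)\n",
        "sampled_ids = rng.choice(\n",
        "    candidate_ids, size=5, replace=False, p=probabilities).tolist()\n",
        "print(sampled_ids)"
      ]
    },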
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zaoo661LOaCK"
      },
      "source": [
        "### Passing `tf.data.Datasets` directly to TFF\n",
        "\n",
        "One option we have for interfacing between a `ClientData` and an `IterativeProcess` is that of constructing `tf.data.Datasets` in Python, and passing these datasets to TFF.\n",
        "\n",
        "Notice that if we use our preprocessed `ClientData` the datasets we yield are of the appropriate type expected by our model defined above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U3R4cvZvPmxt"
      },
      "outputs": [],
      "source": [
        "selected_client_ids = preprocessed_and_shuffled.client_ids[:10]\n",
        "\n",
        "preprocessed_data_for_clients = [\n",
        "    preprocessed_and_shuffled.create_tf_dataset_for_client(\n",
        "        selected_client_ids[i]) for i in range(10)\n",
        "]\n",
        "\n",
        "state = trainer.initialize()\n",
        "for _ in range(5):\n",
        "  t1 = time.time()\n",
        "  state, metrics = trainer.next(state, preprocessed_data_for_clients)\n",
        "  t2 = time.time()\n",
        "  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XFaFlB59nAVi"
      },
      "source": [
        "If we take this route, however, we will be ***unable to trivially move to multimachine simulation***. The datasets we construct in the local TensorFlow runtime can *capture state from the surrounding python environment*, and fail in serialization or deserialization when they attempt to reference state which is no longer available to them. This can manifest for example in the inscrutable error from TensorFlow's `tensor_util.cc`:\n",
        "```\n",
        "Check failed: DT_VARIANT == input.dtype() (21 vs. 20)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q5VKu7OLny5X"
      },
      "source": [
        "### Mapping construction and preprocessing over the clients\n",
        "\n",
        "To avoid this issue, TFF recommends its users to consider dataset instantiation and preprocessing as *something that happens locally on each client*, and to use TFF's helpers or `federated_map` to explicitly run this preprocessing code at each client.\n",
        "\n",
        "Conceptually, the reason for preferring this is clear: in TFF's local runtime, the clients only \"accidentally\" have access to the global Python environment due to the fact that the entire federated orchestration is happening on a single machine. It is worthwhile noting at this point that similar thinking gives rise to TFF's cross-platform, always-serializable, functional philosophy.\n",
        "\n",
        "TFF makes such a change simple via `ClientData's` attribute `dataset_computation`, a `tff.Computation` which takes a `client_id` and returns the associated `tf.data.Dataset`.\n",
        "\n",
        "Note that `preprocess` simply works with `dataset_computation`; the `dataset_computation` attribute of the preprocessed `ClientData` incorporates the entire preprocessing pipeline we just defined:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yKiTjDj3pw4R"
      },
      "outputs": [],
      "source": [
        "print('dataset computation without preprocessing:')\n",
        "print(client_data.dataset_computation.type_signature)\n",
        "print('\\n')\n",
        "print('dataset computation with preprocessing:')\n",
        "print(preprocessed_and_shuffled.dataset_computation.type_signature)"
      ]
    },
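    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check, a `dataset_computation` can be invoked like any other `tff.Computation`. The following is a minimal sketch, assuming the default local execution context and the `first_client_id` defined earlier; the code only relies on the result being iterable and on each batch exposing the `x` and `y` keys produced by `preprocess_dataset`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Invoke the computation on a single client id.\n",
        "client_dataset = preprocessed_and_shuffled.dataset_computation(first_client_id)\n",
        "\n",
        "# Inspect the first batch produced by the preprocessing pipeline.\n",
        "first_batch = next(iter(client_dataset))\n",
        "print(first_batch['x'].shape, first_batch['y'].shape)"
      ]
    },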
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oGcSqAjuqJau"
      },
      "source": [
        "We could invoke `dataset_computation` and receive an eager dataset in the Python runtime, but the real power of this approach is exercised when we compose with an iterative process or another computation to avoid materializing these datasets in the global eager runtime at all. TFF provides a helper function [`tff.simulation.compose_dataset_computation_with_iterative_process`](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/compose_dataset_computation_with_iterative_process) which can be used to do exactly this."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "69vY85cmPsel"
      },
      "outputs": [],
      "source": [
        "trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(\n",
        "    preprocessed_and_shuffled.dataset_computation, trainer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ixrmztq6SbRE"
      },
      "source": [
        "Both this `tff.templates.IterativeProcesses` and the one above run the same way; but former accepts preprocessed client datasets, and the latter accepts strings representing client ids, handling both dataset construction and preprocessing in its body--in fact `state` can be passed between the two."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZcYPQxqlSapn"
      },
      "outputs": [],
      "source": [
        "for _ in range(5):\n",
        "  t1 = time.time()\n",
        "  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)\n",
        "  t2 = time.time()\n",
        "  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))"
      ]
    },
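    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since this process takes client ids directly, choosing a different cohort each round only requires sampling strings in Python. The following is a minimal sketch, assuming uniform sampling of 10 clients per round with a fixed seed; the names `sampling_rng`, `round_client_ids`, and `sampled_state` are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "sampling_rng = random.Random(0)\n",
        "# Continue from the current state without overwriting it.\n",
        "sampled_state = state\n",
        "for round_num in range(3):\n",
        "  # Draw a fresh cohort of client ids for this round.\n",
        "  round_client_ids = sampling_rng.sample(\n",
        "      preprocessed_and_shuffled.client_ids, 10)\n",
        "  sampled_state, round_metrics = trainer_accepting_ids.next(\n",
        "      sampled_state, round_client_ids)\n",
        "  print('round {}, loss {}'.format(round_num, round_metrics['train']['loss']))"
      ]
    },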
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SeoQzU-5XeGz"
      },
      "source": [
        "### Scaling to large numbers of clients\n",
        "\n",
        "`trainer_accepting_ids` can immediately be used in TFF's multimachine runtime, and avoids materializing `tf.data.Datasets` and the controller (and therefore serializing them and sending them out to the workers).\n",
        "\n",
        "This significantly speeds up distributed simulations, especially with a large number of clients, and enables intermediate aggregation to avoid similar serialization/deserialization overhead.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iSy1t2UZQWCy"
      },
      "source": [
        "### Optional deepdive: manually composing preprocessing logic in TFF\n",
        "\n",
        "TFF is designed for compositionality from the ground up; the kind of composition just performed by TFF's helper is fully within our control as users. We could have manually compose the preprocessing computation we just defined with the trainer's own `next` quite simply:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yasFmYyIwTKY"
      },
      "outputs": [],
      "source": [
        "selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)\n",
        "\n",
        "@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)\n",
        "def new_next(server_state, selected_clients):\n",
        "  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)\n",
        "  return trainer.next(server_state, preprocessed_data)\n",
        "\n",
        "manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pHG0NXbWQuk7"
      },
      "source": [
        "In fact, this is effectively what the helper we used is doing under the hood (plus performing appropriate type checking and manipulation). We could even have expressed the same logic slightly differently, by serializing `preprocess_and_shuffle` into a `tff.Computation`, and decomposing the `federated_map` into one step which constructs un-preprocessed datasets and another which runs `preprocess_and_shuffle` at each client.\n",
        "\n",
        " We can verify that this more-manual path results in computations with the same type signature as TFF's helper (modulo parameter names):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C2sc5HkLPwkp"
      },
      "outputs": [],
      "source": [
        "print(trainer_accepting_ids.next.type_signature)\n",
        "print(manual_trainer_with_preprocessing.next.type_signature)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "working_with_client_data.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
